#!/usr/bin/env python3

#####################################################################
#
# pfcstat is a tool for summarizing Priority-based Flow Control (PFC) statistics.
#
#####################################################################

import json
import argparse
import datetime
import os.path
import sys

from collections import namedtuple, OrderedDict
from copy import deepcopy
from natsort import natsorted
from tabulate import tabulate

from sonic_py_common.multi_asic import get_external_ports

# Mock the redis database connector for unit-test purposes.
#
# Both environment lookups below use os.environ[...], so an unset variable
# raises KeyError. That KeyError is deliberately swallowed, which also skips
# whatever remains of the try body (e.g. if UTILITIES_UNIT_TESTING is "2" but
# UTILITIES_UNIT_TESTING_TOPOLOGY is unset, only the first step has run).
# `os` is available here because `import os.path` above also binds `os`.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        # Put the parent directory of this script and its "tests"
        # subdirectory at the front of sys.path so mock_tables is importable.
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        # NOTE(review): this branch uses mock_tables.dbconnector, which is only
        # imported explicitly by the branch above -- presumably both variables
        # are always set together in the test environment; confirm.
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()

except KeyError:
    pass

from utilities_common.netstat import ns_diff, STATUS_NA, format_number_with_comma, format_microseconds_as_datetime
from utilities_common import multi_asic as multi_asic_util
from utilities_common import constants
from utilities_common.cli import json_serial, UserCache


# Column titles for the eight PFC priorities, PFC0 through PFC7.
pfc_titles = ['PFC{}'.format(priority) for priority in range(8)]

# One counter value per priority; field names are the lower-cased titles
# (pfc0 .. pfc7).
PStats = namedtuple("PStats", [title.lower() for title in pfc_titles])

# Table headers for the receive and transmit counter tables.
header_Rx = ['Port Rx', *pfc_titles]
header_Tx = ['Port Tx', *pfc_titles]

# Map each SAI per-priority RX counter name to its position in PStats.
counter_bucket_rx_dict = {
    'SAI_PORT_STAT_PFC_{}_RX_PKTS'.format(priority): priority
    for priority in range(len(pfc_titles))
}

# Map each SAI per-priority TX counter name to its position in PStats.
counter_bucket_tx_dict = {
    'SAI_PORT_STAT_PFC_{}_TX_PKTS'.format(priority): priority
    for priority in range(len(pfc_titles))
}


# Per-priority history record: two "total" stats followed by two "recent"
# stats.
HistStats = namedtuple(
    "HistStats",
    ["numTransitions", "totalPauseTime", "recentPauseTimestamp", "recentPauseTime"],
)

# Prefixes tried (SAI first, then EST) in front of the total stat names.
SAI_PREFIX, EST_PREFIX = "SAI", "EST"

# Stat name templates; '*' is replaced by the priority index and a prefix
# (SAI_PREFIX or EST_PREFIX) is prepended at lookup time.
total_stat_fields = [
    "_PORT_STAT_PFC_*_ON2OFF_RX_PKTS",
    "_PORT_STAT_PFC_*_RX_PAUSE_DURATION_US",
]

# Stat name templates that already carry the EST prefix; '*' is replaced by
# the priority index at lookup time.
recent_stat_fields = [
    "EST_PORT_STAT_PFC_*_RECENT_PAUSE_TIMESTAMP",
    "EST_PORT_STAT_PFC_*_RECENT_PAUSE_TIME_US",
]

# Column headers of the history table.
header_hist = ['Port', 'Priority'] + [
    'RX Pause Transitions',
    'Total RX Pause Time US',
    'Recent RX Pause Time US',
    'Recent RX Pause Timestamp',
]

# Keys used when reading the counters database.
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"
COUNTER_TABLE_PREFIX = "COUNTERS:"

class Pfcstat(object):
    """
        Collect per-port PFC counters and PFC history stats from the
        counters database and print them as tables.
    """
    def __init__(self, namespace, display):
        """
            Create the multi-asic helper and the empty result dictionaries.

            namespace: namespace passed through to MultiAsic (main passes
                       None when no namespace is selected)
            display:   display option passed through to MultiAsic
        """
        self.multi_asic = multi_asic_util.MultiAsic(display, namespace)
        # NOTE(review): db and config_db are never assigned in this class;
        # presumably the run_on_multi_asic decorator connects them for the
        # current namespace before the wrapped method runs -- confirm.
        self.db = None
        self.config_db = None
        self.cnstat_dict = OrderedDict()
        self.hist_dict = OrderedDict()

    def _get_display_ports(self):
        """
            Return the ports to report as a list of (port name, counter
            table id) pairs in natural sort order, or None when the port
            name map cannot be read from the database.
        """
        counter_port_name_map = self.db.get_all(
            self.db.COUNTERS_DB, COUNTERS_PORT_NAME_MAP
        )
        if counter_port_name_map is None:
            return None
        display_ports_set = set(counter_port_name_map.keys())
        if self.multi_asic.display_option == constants.DISPLAY_EXTERNAL:
            display_ports_set = get_external_ports(
                display_ports_set, self.multi_asic.current_namespace
            )
        return [
            (port, counter_port_name_map[port])
            for port in natsorted(counter_port_name_map)
            if port in display_ports_set
        ]

    @multi_asic_util.run_on_multi_asic
    def collect_cnstat(self, rx):
        """
            Get the counters info from database and merge it into
            self.cnstat_dict.

            rx: True to read the RX counters, False to read the TX counters
        """
        bucket_dict = counter_bucket_rx_dict if rx else counter_bucket_tx_dict

        def get_counters(table_id):
            """
                Get the counters from specific table. A counter missing
                from the database is reported as STATUS_NA.
            """
            full_table_id = COUNTER_TABLE_PREFIX + table_id
            fields = ["0"] * len(PStats._fields)
            for counter_name, pos in bucket_dict.items():
                counter_data = self.db.get(
                    self.db.COUNTERS_DB, full_table_id, counter_name
                )
                if counter_data is None:
                    fields[pos] = STATUS_NA
                else:
                    fields[pos] = str(int(counter_data))
            return PStats._make(fields)._asdict()

        ports = self._get_display_ports()
        if ports is None:
            return
        # Build a dictionary of the stats; the timestamp is taken before
        # the per-port counters are read.
        cnstat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()
        for port, table_id in ports:
            cnstat_dict[port] = get_counters(table_id)
        self.cnstat_dict.update(cnstat_dict)

    @multi_asic_util.run_on_multi_asic
    def collect_history(self):
        """
            Get the history stats from database and merge them into
            self.hist_dict.
        """
        def get_stat(full_table_id, stat_name, pfc_index, prefix=""):
            """
                Read one stat; '*' in stat_name is replaced by pfc_index
                and prefix is prepended. Returns None when it is missing.
            """
            full_stat_name = prefix + stat_name.replace('*', str(pfc_index))
            return self.db.get(
                self.db.COUNTERS_DB, full_table_id, full_stat_name
            )

        def as_field(hist_data):
            """
                Convert a raw stat into its table field (STATUS_NA if None).
            """
            return STATUS_NA if hist_data is None else str(hist_data)

        def get_history_stats(table_id):
            """
                Get the history from specific table, keyed by PFC title.
            """
            pfc_dict = OrderedDict()
            full_table_id = COUNTER_TABLE_PREFIX + table_id

            for pfc_index, pfc in enumerate(pfc_titles):
                fields = []
                # SAI or EST options: prefer the SAI stat, fall back to EST
                for stat_name in total_stat_fields:
                    hist_data = get_stat(
                        full_table_id, stat_name, pfc_index, SAI_PREFIX
                    )
                    if hist_data is None:
                        hist_data = get_stat(
                            full_table_id, stat_name, pfc_index, EST_PREFIX
                        )
                    fields.append(as_field(hist_data))

                # EST options only (the names already carry the prefix)
                for stat_name in recent_stat_fields:
                    fields.append(
                        as_field(get_stat(full_table_id, stat_name, pfc_index))
                    )

                pfc_dict[pfc] = HistStats._make(fields)._asdict()
            return pfc_dict

        ports = self._get_display_ports()
        if ports is None:
            return
        # Build a dictionary of the stats; the timestamp is taken before
        # the per-port stats are read.
        hist_dict = OrderedDict()
        hist_dict['time'] = datetime.datetime.now()
        for port, table_id in ports:
            hist_dict[port] = get_history_stats(table_id)
        self.hist_dict.update(hist_dict)

    def get_cnstat(self, rx):
        """
            Get the counters info from database.

            Returns self.cnstat_dict, which is cleared and refilled on every
            call; callers that need to keep a result must copy it.
        """
        self.cnstat_dict.clear()
        self.collect_cnstat(rx)
        return self.cnstat_dict

    def get_history(self):
        """
            Get the history stats from database. These values are populated by the pfcwd detect script.

            Returns self.hist_dict, which is cleared and refilled on every
            call; callers that need to keep a result must copy it.
        """
        self.hist_dict.clear()
        self.collect_history()
        return self.hist_dict

    @staticmethod
    def _cnstat_row(port, cntr, old_cntr=None):
        """
            Build one counter table row for port: the difference against
            old_cntr when it is given, the formatted new values otherwise.
        """
        if old_cntr is None:
            values = [format_number_with_comma(cntr[field])
                      for field in PStats._fields]
        else:
            values = [ns_diff(cntr[field], old_cntr[field])
                      for field in PStats._fields]
        return tuple([port] + values)

    @staticmethod
    def _print_cnstat_table(table, rx):
        """
            Print the counter rows under the RX or TX header.
        """
        headers = header_Rx if rx else header_Tx
        print(tabulate(table, headers, tablefmt='simple', stralign='right'))

    def cnstat_print(self, cnstat_dict, rx):
        """
            Print the cnstat.

            cnstat_dict: result of get_cnstat ('time' entry is skipped)
            rx:          True for the RX header, False for the TX header
        """
        table = [
            self._cnstat_row(key, data)
            for key, data in cnstat_dict.items()
            if key != 'time'
        ]
        self._print_cnstat_table(table, rx)

    def cnstat_diff_print(self, cnstat_new_dict, cnstat_old_dict, rx):
        """
            Print the difference between two cnstat results.

            Ports without an entry in cnstat_old_dict are printed with
            their new values only.
        """
        table = [
            self._cnstat_row(key, cntr, cnstat_old_dict.get(key))
            for key, cntr in cnstat_new_dict.items()
            if key != 'time'
        ]
        self._print_cnstat_table(table, rx)

    def history_diff_print(self, headers, hist_new_dict, hist_old_dict=None):
        """
            Print the difference between two cnstat history results.

            headers:       column headers for the table
            hist_new_dict: result of get_history ('time' entry is skipped)
            hist_old_dict: previously saved result, or None for no baseline

            Only the transition count and total pause time are diffed; the
            recent pause time and timestamp always show the new values.
        """
        # None instead of a mutable {} default, which would be shared
        # between calls.
        if hist_old_dict is None:
            hist_old_dict = {}

        table = []
        # time : <>, ethernet0 : { pfc0: {}, pfc1: {} }, ethernet1 :{}
        for key, pfc_dict in hist_new_dict.items():
            if key == 'time':
                continue
            # A port that is missing (or null) in the old dict has no
            # baseline, so every priority falls back to the new data only.
            old_pfc_dict = hist_old_dict.get(key) or {}
            for pfc, stat in pfc_dict.items():
                old_stat = old_pfc_dict.get(pfc)
                if old_stat is not None:
                    transitions = ns_diff(stat['numTransitions'], old_stat['numTransitions'])
                    total_pause = ns_diff(stat['totalPauseTime'], old_stat['totalPauseTime'])
                else:
                    transitions = format_number_with_comma(stat['numTransitions'])
                    total_pause = format_number_with_comma(stat['totalPauseTime'])
                table.append((key,
                              pfc,
                              transitions,
                              total_pause,
                              format_number_with_comma(stat['recentPauseTime']),
                              format_microseconds_as_datetime(stat['recentPauseTimestamp'])))
            # empty line between interfaces
            table.append([])

        print(tabulate(table, headers, tablefmt='simple', stralign='right'))

def _save_stats(stats, path):
    """
        Write stats as JSON to the cache file at path.
    """
    # The context manager closes the file even when json.dump raises.
    with open(path, 'w') as stats_file:
        json.dump(stats, stats_file, default=json_serial)


def _load_stats(path):
    """
        Return the JSON-decoded contents of the cache file at path.
    """
    with open(path, 'r') as stats_file:
        return json.load(stats_file)


def _show_cnstat(pfcstat, cnstat_dict, cache_path, rx):
    """
        Print the RX (rx=True) or TX (rx=False) counters: as a difference
        against the cache file at cache_path when it exists, as plain
        values otherwise. Exits with the errno on an IOError.
    """
    if not os.path.isfile(cache_path):
        pfcstat.cnstat_print(cnstat_dict, rx)
        return
    try:
        cnstat_cached_dict = _load_stats(cache_path)
        print("Last cached time was " + str(cnstat_cached_dict.get('time')))
        pfcstat.cnstat_diff_print(cnstat_dict, cnstat_cached_dict, rx)
    except IOError as e:
        print(e.errno, e)
        sys.exit(e.errno)


def main():
    """
        Parse the command line, then either save fresh stats to the user
        cache (-c), or print the PFC history (--history) or the PFC RX/TX
        counters, relative to the cached stats when a cache file exists.
        Always terminates through sys.exit.
    """
    parser  = argparse.ArgumentParser(description='Display the pfc counters',
                                      formatter_class=argparse.RawTextHelpFormatter,
                                      epilog="""
Examples:
  pfcstat
  pfcstat -c
  pfcstat -d
  pfcstat -n asic1
  pfcstat -s all -n asic0
  pfcstat --history
  pfcstat -n asic1 --history
  pfcstat -s all -n asic0 --history
""")

    parser.add_argument( '-c', '--clear', action='store_true',
        help='Clear previous stats and save new ones'
    )
    parser.add_argument(
        '-d', '--delete', action='store_true', help='Delete saved stats'
    )
    parser.add_argument('-s', '--show', default=constants.DISPLAY_EXTERNAL,
        help='Display all interfaces or only external interfaces'
    )
    parser.add_argument('-n', '--namespace', default=None,
        help='Display interfaces for specific namespace'
    )
    parser.add_argument(
        '-v', '--version', action='version', version='%(prog)s 1.0'
    )
    parser.add_argument(
        '--history', action="store_true", help="Display historical PFC statistics, must be supported by the platform-specific PFCWD detect script."
    )

    args = parser.parse_args()

    save_fresh_stats = args.clear
    delete_all_stats = args.delete
    show_pfc_history = args.history

    cache = UserCache()
    cnstat_file = 'pfcstat'

    cnstat_dir = cache.get_directory()
    cnstat_fqn_file_rx = os.path.join(cnstat_dir, "{}rx".format(cnstat_file))
    cnstat_fqn_file_tx = os.path.join(cnstat_dir, "{}tx".format(cnstat_file))
    hist_fqn_file = cnstat_fqn_file_rx + "_hist"

    # if '-c' option is provided get stats from all (frontend and backend) ports
    if save_fresh_stats:
        args.namespace = None
        args.show = constants.DISPLAY_ALL

    pfcstat = Pfcstat(args.namespace, args.show)

    if delete_all_stats:
        cache.remove()

    # clear both counters + history
    if save_fresh_stats:
        # get_cnstat/get_history hand back the collector's own dictionaries,
        # which are cleared on the next call, hence the deep copies.
        # Get the counters of pfc rx counter
        cnstat_dict_rx = deepcopy(pfcstat.get_cnstat(True))
        # Get the counters of pfc tx counter
        cnstat_dict_tx = deepcopy(pfcstat.get_cnstat(False))
        # Get the history stats of pfc rx
        hist_dict = deepcopy(pfcstat.get_history())

        try:
            _save_stats(cnstat_dict_rx, cnstat_fqn_file_rx)
            _save_stats(cnstat_dict_tx, cnstat_fqn_file_tx)
            _save_stats(hist_dict, hist_fqn_file)
        except IOError as e:
            print(e.errno, e)
            sys.exit(e.errno)
        else:
            print("Clear saved PFC counters and history")
            sys.exit(0)

    # save history stats separately
    if show_pfc_history:
        # Get the history stats of pfc rx
        hist_dict = deepcopy(pfcstat.get_history())

        # Print the pfc history stats
        if os.path.isfile(hist_fqn_file):
            try:
                hist_cached_dict = _load_stats(hist_fqn_file)
                print("Last cached time was " + str(hist_cached_dict.get('time')))
                pfcstat.history_diff_print(header_hist, hist_dict, hist_cached_dict)
            except IOError as e:
                print(e.errno, e)
                sys.exit(e.errno)
        else:
            pfcstat.history_diff_print(header_hist, hist_dict)

    else:
        # Get the counters of pfc rx counter
        cnstat_dict_rx = deepcopy(pfcstat.get_cnstat(True))

        # Get the counters of pfc tx counter
        cnstat_dict_tx = deepcopy(pfcstat.get_cnstat(False))

        # Print the counters of pfc rx counter
        _show_cnstat(pfcstat, cnstat_dict_rx, cnstat_fqn_file_rx, True)

        print("")

        # Print the counters of pfc tx counter
        _show_cnstat(pfcstat, cnstat_dict_tx, cnstat_fqn_file_tx, False)

    sys.exit(0)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
